[RLlib] RLModule API: SelfSupervisedLossAPI for RLModules that bring their own loss (algo independent). #47581
Conversation
…odule_api_self_supervised_loss
LGTM.
rllib/algorithms/ppo/ppo_learner.py (Outdated)

```python
@abc.abstractmethod
def _update_module_kl_coeff(
    self,
    *,
    module_id: ModuleID,
    config: PPOConfig,
    kl_loss: float,
) -> None:
    """Dynamically update the KL loss coefficients of each module with.
```
"module with"?
fixed
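For context on the hook itself: a concrete Torch override might follow the classic adaptive-KL schedule, as in the sketch below. This is illustrative only; the `curr_kl_coeffs_per_module` store and the 2.0/0.5 thresholds with 1.5/0.5 multipliers are assumptions, not necessarily what the PR implements.

```python
from ray.rllib.algorithms.ppo.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner
from ray.rllib.utils.typing import ModuleID


class MyPPOTorchLearner(PPOLearner):
    def _update_module_kl_coeff(
        self,
        *,
        module_id: ModuleID,
        config: PPOConfig,
        kl_loss: float,
    ) -> None:
        # Assumed: a per-module dict of torch tensors holding the
        # current KL coefficients, maintained by the Learner.
        kl_coeff = self.curr_kl_coeffs_per_module[module_id]
        # Grow the penalty when the observed KL overshoots the target,
        # shrink it when it undershoots (classic adaptive-KL rule).
        if kl_loss > 2.0 * config.kl_target:
            kl_coeff.data *= 1.5
        elif kl_loss < 0.5 * config.kl_target:
            kl_coeff.data *= 0.5
```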
```python
learner_config_dict={
    # Intrinsic reward coefficient.
    "intrinsic_reward_coeff": 0.05,
    # Forward loss weight (vs inverse dynamics loss). Total ICM loss is:
```
Very nice comment!
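To sketch how these two config entries could be consumed (hedged: the diff is truncated above, so the `forward_loss_weight` key and all tensor names here are illustrative placeholders):

```python
import torch

# Placeholder config and tensors (illustrative names and values only).
learner_config_dict = {"intrinsic_reward_coeff": 0.05, "forward_loss_weight": 0.2}
forward_loss = torch.tensor(1.3)    # forward-dynamics head loss
inverse_loss = torch.tensor(0.7)    # inverse-dynamics head loss
extrinsic_rewards = torch.zeros(4)
intrinsic_rewards = torch.rand(4)   # e.g. forward-model prediction error

w = learner_config_dict["forward_loss_weight"]
# Total ICM loss: convex combination of the two ICM heads.
icm_loss = w * forward_loss + (1.0 - w) * inverse_loss

# Intrinsic rewards are scaled and added on top of the env rewards.
rewards = extrinsic_rewards + learner_config_dict["intrinsic_reward_coeff"] * intrinsic_rewards
```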
```python
class DQNTorchLearnerWithCuriosity(DQNRainbowTorchLearner):
    def build(self) -> None:
```
Dumb question: Can't we just override `AlgorithmConfig.build_learner_pipeline()`?
We could, but again, this should be done inside the Learner, imo.

But you have a good point: How can we make this even easier for the user? Maybe offer a better way to customize the Learner pipeline? Currently, users can only prepend connector pieces to the beginning, then RLlib adds the default pieces to the end. But here, we need a (custom) connector piece to move all the way to the end, which is not possible with the `config.learner_connector` property.
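A sketch of the asymmetry described above (hedged: `ICMConnector` is a hypothetical connector piece, and the pipeline attribute/method names are assumptions about RLlib's ConnectorV2 stack):

```python
# Today: pieces passed via the config are PREPENDED; RLlib then adds
# its default pieces after them.
config.training(
    learner_connector=lambda obs_space, act_space: ICMConnector(),
)

# What the curiosity example needs instead: a piece at the very END of
# the pipeline, which currently requires overriding the Learner's build().
class DQNTorchLearnerWithCuriosity(DQNRainbowTorchLearner):
    def build(self) -> None:
        super().build()
        # Assumed pipeline attribute/method; appends after all defaults.
        self._learner_connector.append(ICMConnector())
```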
```python
    ],
    dim=0,
)
obs = tree.map_structure(
```
Very nice!
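The pattern being praised: `tree.map_structure` batches arbitrarily nested observations leaf-by-leaf while preserving the nesting structure. A minimal standalone illustration with made-up data (`tree` is the `dm_tree` package that RLlib imports under that name):

```python
import numpy as np
import tree  # dm_tree

# Two nested observations, e.g. from consecutive timesteps.
obs_0 = {"camera": np.zeros((84, 84)), "state": np.array([0.0, 1.0])}
obs_1 = {"camera": np.ones((84, 84)), "state": np.array([2.0, 3.0])}

# Stack leaf-by-leaf along a new batch axis; only the leaves change
# shape, the dict structure stays intact.
obs = tree.map_structure(lambda *leaves: np.stack(leaves, axis=0), obs_0, obs_1)
assert obs["camera"].shape == (2, 84, 84)
assert obs["state"].shape == (2, 2)
```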
```diff
     *,
     learner: "TorchLearner",
     module_id: ModuleID,
     config: "AlgorithmConfig",
     batch: Dict[str, Any],
     fwd_out: Dict[str, Any],
 ) -> Dict[str, Any]:
-    module = learner.module[module_id]
+    module = learner.module[module_id].unwrapped()
```
I guess we need this for DDP?
correct
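For readers following along: under multi-GPU training the Learner may hand back each RLModule wrapped in a DDP-style wrapper (e.g. around torch's `DistributedDataParallel`), so attributes defined only on the user's module wouldn't resolve on the wrapper. A hedged illustration (`_inverse_net` is hypothetical):

```python
module = learner.module[module_id]
# `module` may be a DDP-style wrapper here; a custom attribute such as
# `module._inverse_net` (hypothetical) could raise AttributeError.

module = learner.module[module_id].unwrapped()
# `unwrapped()` returns the plain RLModule (and is a no-op when no
# wrapper is present), so custom attributes and methods resolve as
# expected in both single- and multi-GPU setups.
```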
```diff
-    @staticmethod
-    def compute_loss_for_module(
+    @override(SelfSupervisedLossAPI)
+    def compute_self_supervised_loss(
```
What somehow irritates me is that we are putting the loss function into the module, but still build a special learner to handle this. Instead, we could directly override the Learner's `compute_loss_for_module`, couldn't we?
See my answer to your comment above.
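Tying the thread together, a module opting into the new API might look roughly like the sketch below. Only the hook's signature follows the diffs above; the class name, loss body, and dict keys are illustrative assumptions:

```python
from typing import Any, Dict

from ray.rllib.core.rl_module.apis import SelfSupervisedLossAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import ModuleID


class ICMTorchRLModule(TorchRLModule, SelfSupervisedLossAPI):
    @override(SelfSupervisedLossAPI)
    def compute_self_supervised_loss(
        self,
        *,
        learner: "TorchLearner",
        module_id: ModuleID,
        config: "AlgorithmConfig",
        batch: Dict[str, Any],
        fwd_out: Dict[str, Any],
    ):
        # The module owns its loss entirely; the Learner calls this hook
        # instead of its own `compute_loss_for_module()`.
        w = config.learner_config_dict["forward_loss_weight"]  # assumed key
        # Illustrative `fwd_out` keys for the two ICM heads:
        return (
            w * fwd_out["forward_loss"]
            + (1.0 - w) * fwd_out["inverse_loss"]
        )
```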
…odule_api_self_supervised_loss
Signed-off-by: sven1977 <svenmika1977@gmail.com>
# Conflicts:
#	rllib/core/rl_module/apis/__init__.py
…odule_api_self_supervised_loss
…g their own loss (algo independent). (ray-project#47581) Signed-off-by: ujjawal-khare <ujjawal.khare@dream11.com>
RLModule API: `SelfSupervisedLossAPI` for RLModules that bring their own loss (algo independent). The Learner now checks whether any RLModule (in the MultiRLModule) implements this API and, if yes, calls the Module's own `compute_self_supervised_loss` method (instead of the Learner's `compute_loss_for_module()` method).
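The check described above could be pictured like this (a hedged sketch of the control flow, not the actual Learner source):

```python
# Per-module loss dispatch inside the Learner (sketch):
module = self.module[module_id].unwrapped()

if isinstance(module, SelfSupervisedLossAPI):
    # The RLModule brings its own, algo-independent loss.
    loss = module.compute_self_supervised_loss(
        learner=self,
        module_id=module_id,
        config=self.config.get_config_for_module(module_id),
        batch=batch,
        fwd_out=fwd_out,
    )
else:
    # Default path: the algorithm's Learner defines the loss.
    loss = self.compute_loss_for_module(
        module_id=module_id,
        config=self.config.get_config_for_module(module_id),
        batch=batch,
        fwd_out=fwd_out,
    )
```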
Why are these changes needed?

Related issue number
Checks

- I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- I've run `scripts/format.sh` to lint the changes in this PR.
- If I've added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.